import mindspore.nn as nn
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
from src.common import Conv2d, ConvBNReLU


class Fluff(nn.Cell):
    """
    Latticed multi-scale feature fusion block.

    paper: Fast Object Detection with Latticed Multi-Scale Feature Fusion, http://arxiv.org/abs/2011.02780

    Builds a ``num_levels`` x ``num_branches`` lattice of dilated 3x3
    convolutions. Level 0 first reduces the input to ``mid_channels``; every
    subsequent level refines each branch with its own dilation rate. The
    intermediate output of EVERY (level, branch) cell is concatenated along
    the channel axis, fused back to ``out_channels`` by a 1x1 convolution,
    and added to a (possibly 1x1-projected) shortcut of the input.
    """

    def __init__(self, in_channels, out_channels, reduction_ratio=4, num_levels=3, num_branches=4,
                 dilatation_ratios=None):
        """
        Args:
            in_channels (int): channels of the input feature map.
            out_channels (int): channels of the output feature map.
            reduction_ratio (int): channel-reduction factor for the branch convs.
            num_levels (int): lattice depth; ignored (overridden) when
                ``dilatation_ratios`` is given.
            num_branches (int): parallel branches per level; ignored
                (overridden) when ``dilatation_ratios`` is given.
            dilatation_ratios (list[list[int]] | None): per-level, per-branch
                dilation rates. Defaults to the 3x4 lattice from the paper.
        """
        super(Fluff, self).__init__()
        self.in_channels = in_channels
        self.mid_channels = in_channels // reduction_ratio
        self.out_channels = out_channels
        self.num_levels = num_levels
        self.num_branches = num_branches
        if dilatation_ratios is None:
            # Default 3-level x 4-branch dilation lattice from the paper.
            self.dilatation_ratios = [[1, 2, 3, 6], [1, 2, 5, 9], [1, 2, 3, 6]]
            assert num_levels <= 3 and num_branches <= 4
        else:
            # Custom lattice: the grid size is inferred from the ratios, so the
            # num_levels / num_branches arguments are overridden here.
            self.dilatation_ratios = dilatation_ratios
            self.num_levels = len(self.dilatation_ratios)
            self.num_branches = len(self.dilatation_ratios[0])
        self.convs = []
        for i in range(self.num_levels):
            level_i = []
            for j in range(self.num_branches):
                d = self.dilatation_ratios[i][j]
                if i == 0:
                    # Only the first level reduces in_channels -> mid_channels.
                    modules = [Conv2d(in_channels, self.mid_channels, 1, 1, 0, bias=False),
                               nn.BatchNorm2d(self.mid_channels)]
                else:
                    modules = []
                modules.extend([
                    nn.ReLU(),
                    Conv2d(self.mid_channels, self.mid_channels, 3, 1, d, d, bias=False),
                    nn.BatchNorm2d(self.mid_channels),
                ])
                level_i.append(nn.SequentialCell(modules))
            self.convs.append(nn.CellList(level_i))
        self.convs = nn.CellList(self.convs)

        self.conv_concat = nn.SequentialCell(
            nn.ReLU(),
            # BUGFIX: use self.num_levels / self.num_branches instead of the raw
            # constructor arguments -- they differ whenever a custom
            # dilatation_ratios of a different grid size is supplied, which
            # previously gave this 1x1 fusion conv the wrong input channel count.
            Conv2d(self.mid_channels * self.num_levels * self.num_branches, self.out_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.out_channels)
        )
        if self.in_channels != self.out_channels:
            # Project the residual path when channel counts differ.
            self.shortcut = nn.SequentialCell(
                Conv2d(self.in_channels, self.out_channels, 1, 1, 0, bias=False),
                nn.BatchNorm2d(self.out_channels)
            )
        else:
            self.shortcut = None
        self.relu = nn.ReLU()
        self.cat = P.Concat(axis=1)
        self.add = P.TensorAdd()

    def construct(self, x):
        """Fuse the lattice of multi-scale features and add the shortcut.

        Args:
            x (Tensor): NCHW feature map with ``in_channels`` channels.

        Returns:
            Tensor: NCHW feature map with ``out_channels`` channels.
        """
        features = []
        xs = []
        # Every branch starts from the same input tensor.
        for i in range(self.num_branches):
            xs.append(x)
        # Refine each branch level by level; the intermediate output of every
        # (level, branch) cell is retained for the final concat (the lattice).
        for i in range(self.num_levels):
            for j in range(self.num_branches):
                xs[j] = self.convs[i][j](xs[j])
                features.append(xs[j])
        # Build the Concat operand element by element as a tuple -- presumably
        # for graph-mode compatibility of a variable-length operand list.
        cat_list = ()
        for t in features:
            cat_list = cat_list + (t,)
        y = self.cat(cat_list)
        y = self.conv_concat(y)
        if self.shortcut is not None:
            x = self.shortcut(x)
        x = self.relu(self.add(x, y))
        return x


def _test():
    """Smoke-test Fluff and compare it with the PyTorch reference implementation.

    Requires a GPU MindSpore build plus the PyTorch reference
    (``extension.blocks.Fluff``) and the ``src.convert`` weight-conversion
    helper; intended for manual runs only.
    """
    import torch
    import mindspore.context as context
    from extension.blocks import Fluff as Fluff_torch
    from src import convert
    from mindspore.train.serialization import load_param_into_net
    context.set_context(device_target="GPU")

    # Forward a random batch through the MindSpore model as a shape sanity check.
    x = np.random.rand(2, 64, 64, 64).astype(np.float32)
    net = Fluff(64, 64)
    y = net(Tensor(x))
    print(y.shape)
    # NOTE(review): `compile_and_run` is not a documented public nn.Cell API --
    # confirm it exists on the targeted MindSpore version.
    net.compile_and_run(Tensor(x))

    # Build the PyTorch reference, copy its (converted) weights into the
    # MindSpore net, and report the max absolute difference of the outputs.
    net2 = Fluff_torch(64, 64)
    net2.eval()
    load_param_into_net(net, convert(net2.state_dict()))
    y2 = net2(torch.from_numpy(x))
    y = net(Tensor(x))
    print(np.abs(y.asnumpy() - y2.cpu().detach().numpy()).max())


if __name__ == '__main__':
    # Manual parity check against the PyTorch reference; needs GPU + torch.
    _test()
